%% Generate RCT2Scale grouping 1

% Fix the random seed so that the random steps further down are reproducible
seed = 101;
rng(seed)

% Load the nudge-unit RCT estimates from the workbook
nudgeFile = '../Data/NudgeUnits.xlsx';
opts = detectImportOptions(nudgeFile);
T = readtable(nudgeFile, opts);

% Keep only the RCTs that are not flagged as published (publication == 1)
T = T(~(T.publication == 1), :);

% Distinct policy areas present in the remaining trials
uniquePolicyareas = unique(T.policyarea);

% Names of the binary feature columns (mechanism and medium indicators),
% stored as a 12-by-1 cell column
binaryVars = { ...
    'mech_choicedesign'; ...
    'mech_framing'; ...
    'mech_personalmotivation'; ...
    'mech_planning'; ...
    'mech_simplification'; ...
    'mech_socialcues'; ...
    'medium_email'; ...
    'medium_inperson'; ...
    'medium_other'; ...
    'medium_physicalletter'; ...
    'medium_postcard'; ...
    'medium_website'};


% Partition the trials of every policy area into groups of exactly
% groupSize trials, using spectral clustering on the binary feature columns
% listed in binaryVars.
%
% Result: clusters.(fieldName) is a k-by-1 cell array whose cells each hold
% the ids of one group. Policy areas with fewer than groupSize trials are
% stored as an empty cell array.
groupSize = 3;
clusters = struct();

for i = 1:numel(uniquePolicyareas)
    currentArea = uniquePolicyareas{i};

    % Convert the label to a valid structure field name (spaces/punctuation)
    % and make sure two different labels never collapse onto the same field,
    % which would silently overwrite an earlier policy area.
    validFieldName = matlab.lang.makeUniqueStrings( ...
        matlab.lang.makeValidName(currentArea), fieldnames(clusters));

    % Extract subset for this policy area
    subset = T(strcmp(T.policyarea, currentArea), :);
    n = size(subset, 1);

    if n < groupSize
        % Not enough trials to form a single group
        clusters.(validFieldName) = {};
        continue;
    end

    % Keep the largest multiple of groupSize; the trailing leftover rows are
    % dropped. n stays >= groupSize here, so at least one group remains.
    n = n - mod(n, groupSize);
    subset = subset(1:n, :);
    k = n / groupSize;   % number of groups to form

    % Similarity (adjacency) matrix from Euclidean distance between the
    % binary feature rows: adjacency = 1 / (1 + distance).
    M = subset{:, binaryVars};   % n-by-12 feature array
    distMat = squareform(pdist(M, 'euclidean'));
    adjMat = 1 ./ (1 + distMat);

    % Symmetrically normalised affinity D^(-1/2) * A * D^(-1/2), built from
    % the degree vector instead of a dense matrix power.
    dInvSqrt = 1 ./ sqrt(sum(adjMat, 2));
    dInvSqrt(~isfinite(dInvSqrt)) = 0;   % guard against a zero row sum
    L = adjMat .* (dInvSqrt * dInvSqrt');
    % Force exact symmetry: round-off asymmetry would make eig fall back to
    % its nonsymmetric algorithm and may yield complex eigenvectors.
    L = (L + L') / 2;

    % Eigenvectors belonging to the k largest eigenvalues
    [V, E] = eig(L, 'vector');
    [~, idx] = sort(E, 'descend');
    U = V(:, idx(1:k));

    % Normalise each row to unit length (eps avoids division by zero)
    U = U ./ (sqrt(sum(U.^2, 2)) + eps);

    % K-means on the row vectors
    clusterIdx = kmeans(U, k, 'MaxIter', 1000, 'Replicates', 10);

    % Post-process to enforce groups of exactly groupSize members, because
    % k-means does not produce equally sized clusters. The sizes always sum
    % to k * groupSize, so whenever one cluster is too large another one is
    % too small; one observation is moved per iteration until none is left
    % oversized.
    % NOTE(review): the moved observation is chosen at random rather than by
    % similarity to the receiving cluster - confirm this is intended.
    clSize = accumarray(clusterIdx, 1, [k, 1]);
    idxBig = find(clSize > groupSize, 1, 'first');
    while ~isempty(idxBig)
        idxSmall = find(clSize < groupSize, 1, 'first');

        % Pick a random member of the oversized cluster and reassign it
        bigObservations = find(clusterIdx == idxBig);
        obsToMove = bigObservations(randi(numel(bigObservations)));
        clusterIdx(obsToMove) = idxSmall;

        clSize(idxBig) = clSize(idxBig) - 1;
        clSize(idxSmall) = clSize(idxSmall) + 1;
        idxBig = find(clSize > groupSize, 1, 'first');
    end

    % Store the ids of each group
    groups = cell(k, 1);
    for c = 1:k
        groups{c} = subset.id(clusterIdx == c);
    end
    clusters.(validFieldName) = groups;
end

% Extract treatment effects, standard errors and sample sizes for every
% group stored in clusters. Each output matrix has one row per group and one
% column per study in the group. Entries whose id cannot be found in T are
% left as NaN (not 0) so that they cannot be mistaken for real estimates.
studiesPerGroup = 3;   % must equal the group size used to build clusters

policyAreas = fieldnames(clusters);

% Count the groups first so the matrices can be preallocated. With no groups
% they are 0-by-3 rather than 0-by-0, which keeps them compatible with
% three column names.
nGroups = 0;
for i = 1:numel(policyAreas)
    nGroups = nGroups + numel(clusters.(policyAreas{i}));
end

EstimatesMatrix = NaN(nGroups, studiesPerGroup);
SEMatrix = NaN(nGroups, studiesPerGroup);
SampleSizeMatrix = NaN(nGroups, studiesPerGroup);

row = 0;
for i = 1:numel(policyAreas)
    currentArea = policyAreas{i};
    groupList = clusters.(currentArea);  % cell array of groups (may be empty)

    for g = 1:numel(groupList)
        theseIDs = groupList{g};  % ids belonging to this group

        if numel(theseIDs) ~= studiesPerGroup
            error('RCT2Scale:badGroupSize', ...
                'Group %d of "%s" has %d members; expected %d.', ...
                g, currentArea, numel(theseIDs), studiesPerGroup);
        end

        % Locate each id in T. ismember returns the first matching row, so
        % duplicate ids do not break the assignment, and it also works when
        % the ids are text rather than numbers.
        [found, rowIndex] = ismember(theseIDs, T.id);
        if ~all(found)
            warning('RCT2Scale:missingID', ...
                '%d id(s) of group %d in "%s" were not found in T; left as NaN.', ...
                sum(~found), g, currentArea);
        end

        row = row + 1;
        cols = find(found);
        src = rowIndex(found);
        EstimatesMatrix(row, cols) = reshape(T.treatmenteffect(src), 1, []);
        SEMatrix(row, cols) = reshape(T.SE(src), 1, []);
        SampleSizeMatrix(row, cols) = reshape(T.trialN(src), 1, []);
    end
end

% Export the three matrices to one Excel workbook, one sheet per matrix.
% NOTE(review): writetable presumably writes in place by default, so rows
% left over from an earlier, longer run could remain in a sheet - confirm,
% and consider 'WriteMode','overwritesheet' if the MATLAB release allows it.
outputFile = '../Data/RCT2Scale_data.xlsx';
studyColumns = {'Study1', 'Study2', 'Study3'};

% Wrap each matrix in a table with one column per study
EstimatesTable = array2table(EstimatesMatrix, 'VariableNames', studyColumns);
SETable = array2table(SEMatrix, 'VariableNames', studyColumns);
SampleSizeTable = array2table(SampleSizeMatrix, 'VariableNames', studyColumns);

% Write the sheets in the order Estimates, SE, Sample Size
writetable(EstimatesTable, outputFile, 'Sheet', 'Estimates');
writetable(SETable, outputFile, 'Sheet', 'SE');
writetable(SampleSizeTable, outputFile, 'Sheet', 'Sample Size');
